import numpy as np

# 10x10 matrix of floats (10 rows of 10 entries); each row appears to sum to
# ~1.0. Returned by obtain_confusion_matrix() for the key 'cifar10_10'.
# Presumably rows are true classes and columns predicted classes — TODO confirm.
# The meaning of the numeric suffix (_10/_50/_100/_200) is not stated in this
# file — TODO document.
# NOTE(review): every entry here is identical to cifar10_100 below, which looks
# like a copy-paste error; verify against the source these numbers came from.
cifar10_10 = [[0.884,0.026,0.004,0.012,0.001,0.001,0.007,0.,0.061,0.004],
                        [0.003,0.983,0.,0.,0.001,0.,0.001,0.,0.005,0.007],
                        [0.158,0.018,0.612,0.106,0.033,0.016,0.05,0.004,0.003,0.],
                        [0.034,0.048,0.038,0.745,0.028,0.053,0.04,0.005,0.007,0.002],
                        [0.029,0.012,0.092,0.111,0.677,0.007,0.056,0.009,0.007,0.],
                        [0.024,0.035,0.058,0.277,0.033,0.54,0.014,0.016,0.001,0.002],
                        [0.021,0.015,0.05,0.151,0.012,0.004,0.737,0.003,0.005,0.002],
                        [0.113,0.047,0.047,0.123,0.101,0.029,0.012,0.512,0.005,0.011],
                        [0.161,0.077,0.004,0.012,0.002,0.002,0.005,0.001,0.729,0.007],
                        [0.05,0.366,0.005,0.016,0.,0.,0.003,0.001,0.029,0.53]]


# 10x10 matrix of floats (10 rows of 10 entries, each row wrapped over two
# source lines); each row appears to sum to ~1.0. Returned by
# obtain_confusion_matrix() for the key 'cifar10_50'.
# Presumably rows are true classes and columns predicted classes — TODO confirm.
cifar10_50 = [[0.965, 0.005, 0.015, 0.   , 0.003, 0.002, 0.002, 0.   , 0.007,
        0.001],
       [0.008, 0.984, 0.004, 0.   , 0.   , 0.   , 0.001, 0.   , 0.001,
        0.002],
       [0.059, 0.001, 0.853, 0.03 , 0.026, 0.015, 0.011, 0.004, 0.   ,
        0.001],
       [0.032, 0.008, 0.057, 0.776, 0.03 , 0.066, 0.018, 0.007, 0.003,
        0.003],
       [0.022, 0.001, 0.063, 0.045, 0.833, 0.012, 0.009, 0.014, 0.001,
        0.   ],
       [0.024, 0.001, 0.058, 0.173, 0.037, 0.692, 0.005, 0.009, 0.   ,
        0.001],
       [0.025, 0.008, 0.102, 0.09 , 0.02 , 0.01 , 0.744, 0.001, 0.   ,
        0.   ],
       [0.063, 0.005, 0.038, 0.076, 0.079, 0.052, 0.007, 0.676, 0.001,
        0.003],
       [0.262, 0.052, 0.008, 0.017, 0.003, 0.   , 0.004, 0.   , 0.644,
        0.01 ],
       [0.112, 0.256, 0.009, 0.011, 0.001, 0.001, 0.001, 0.001, 0.005,
        0.603]]


# 10x10 matrix of floats (10 rows of 10 entries); each row appears to sum to
# ~1.0. Returned by obtain_confusion_matrix() for the key 'cifar10_100'.
# Presumably rows are true classes and columns predicted classes — TODO confirm.
# NOTE(review): every entry here is identical to cifar10_10 above, which looks
# like a copy-paste error; verify against the source these numbers came from.
cifar10_100 = [[0.884,0.026,0.004,0.012,0.001,0.001,0.007,0.,0.061,0.004],
                [0.003,0.983,0.,0.,0.001,0.,0.001,0.,0.005,0.007],
                [0.158,0.018,0.612,0.106,0.033,0.016,0.05,0.004,0.003,0.],
                [0.034,0.048,0.038,0.745,0.028,0.053,0.04,0.005,0.007,0.002],
                [0.029,0.012,0.092,0.111,0.677,0.007,0.056,0.009,0.007,0.],
                [0.024,0.035,0.058,0.277,0.033,0.54,0.014,0.016,0.001,0.002],
                [0.021,0.015,0.05,0.151,0.012,0.004,0.737,0.003,0.005,0.002],
                [0.113,0.047,0.047,0.123,0.101,0.029,0.012,0.512,0.005,0.011],
                [0.161,0.077,0.004,0.012,0.002,0.002,0.005,0.001,0.729,0.007],
                [0.05,0.366,0.005,0.016,0.,0.,0.003,0.001,0.029,0.53]]

# 10x10 matrix of floats (10 rows of 10 entries); each row appears to sum to
# ~1.0. Returned by obtain_confusion_matrix() for the key 'cifar10_200'.
# Presumably rows are true classes and columns predicted classes — TODO confirm.
cifar10_200 = [[0.941, 0.016, 0.006, 0.003, 0.007, 0.   , 0.001, 0.001, 0.025, 0.   ],
                [0.013, 0.984, 0.   , 0.001, 0.   , 0.   , 0.001, 0.   , 0.   , 0.001],
                [0.104, 0.006, 0.732, 0.049, 0.041, 0.013, 0.041, 0.005, 0.006, 0.003],
                [0.051, 0.018, 0.079, 0.725, 0.051, 0.035, 0.025, 0.008, 0.007, 0.001],
                [0.046, 0.003, 0.07 , 0.069, 0.738, 0.006, 0.024, 0.033, 0.009, 0.002],
                [0.037, 0.008, 0.119, 0.311, 0.052, 0.445, 0.006, 0.019, 0.002, 0.001],
                [0.034, 0.01 , 0.085, 0.129, 0.025, 0.007, 0.703, 0.002, 0.005, 0.   ],
                [0.095, 0.01 , 0.063, 0.103, 0.117, 0.05 , 0.006, 0.548, 0.004, 0.004],
                [0.266, 0.106, 0.013, 0.008, 0.003, 0.   , 0.005, 0.002, 0.586, 0.011],
                [0.162, 0.518, 0.003, 0.013, 0.006, 0.   , 0.002, 0.001, 0.024, 0.271]]

def obtain_confusion_matrix(dataset_if):
    """Return a copy of the named 10x10 confusion matrix.

    Args:
        dataset_if: Registry key; one of ``'cifar10_10'``, ``'cifar10_50'``,
            ``'cifar10_100'`` or ``'cifar10_200'``.

    Returns:
        A new list of row lists of floats. A fresh copy is built on every
        call, so in-place edits by the caller (e.g. re-normalizing rows)
        cannot corrupt the module-level tables seen by other callers.

    Raises:
        KeyError: If ``dataset_if`` is not a known key. The message lists
            the valid keys.
    """
    all_confusion_matrix = {
        'cifar10_10': cifar10_10,
        'cifar10_50': cifar10_50,
        'cifar10_100': cifar10_100,
        'cifar10_200': cifar10_200,
    }

    try:
        matrix = all_confusion_matrix[dataset_if]
    except KeyError:
        # Still a KeyError so existing callers that catch it keep working,
        # but with a message that says what the valid choices are.
        valid = ", ".join(sorted(all_confusion_matrix))
        raise KeyError(
            f"Unknown confusion matrix {dataset_if!r}; expected one of: {valid}"
        ) from None

    # Rows are flat lists of floats, so copying each row is enough to fully
    # decouple the result from the shared module-level data.
    return [list(row) for row in matrix]